import torch
import torchvision.transforms.v2 as T


class CustomColorJitter:
    """Random brightness/contrast jitter whose magnitude is set by one ``strength`` knob.

    The base ``brightness`` and ``contrast`` magnitudes are each multiplied by
    ``strength`` and handed to ``torchvision.transforms.v2.ColorJitter``. Per the
    torchvision docs, a magnitude ``m`` means the factor is drawn uniformly from
    ``[max(0, 1 - m), 1 + m]`` and must be non-negative. ``strength=0`` therefore
    collapses both ranges to ``[1, 1]``.

    Args:
        strength: Shared multiplier applied to both ``brightness`` and ``contrast``.
        brightness: Brightness jitter magnitude before scaling.
        contrast: Contrast jitter magnitude before scaling.
    """

    def __init__(self, strength=1.0, brightness=0.8, contrast=0.8):
        # One knob scales both magnitudes before the transform is built.
        scaled_brightness = brightness * strength
        scaled_contrast = contrast * strength
        self.color_jitter = T.ColorJitter(
            brightness=scaled_brightness,
            contrast=scaled_contrast,
        )

    def __call__(self, img: torch.Tensor):
        """Jitter ``img``, treating each slice along dim 0 as its own one-channel image.

        NOTE(review): assumes ``img`` is laid out as (C, H, W); confirm with callers.

        Under that assumption, inserting a singleton axis at position 1 yields
        (C, 1, H, W). ColorJitter reads dim -3 as the channel axis, so it sees a
        batch of C grayscale images and adjusts each channel on its own rather
        than mixing channels through an RGB-to-grayscale conversion. The
        singleton axis is dropped again before returning.

        NOTE(review): v2 transforms sample their random factors once per call,
        so all channels presumably share the same brightness/contrast factors.
        """
        as_gray_batch = img.unsqueeze(1)
        jittered = self.color_jitter(as_gray_batch)
        return jittered.squeeze(1)